{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c9e9f75f",
   "metadata": {},
   "source": [
    "### 线性回归回顾"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86755e76",
   "metadata": {},
   "source": [
    "首先简单回顾一下，线性回归是一种机器学习算法，它用于预测连续的输出变量（也称为目标变量或响应变量），其中输出变量与输入变量之间具有线性关系。\n",
    "\n",
    "在线性回归中，我们假设输出变量与输入变量之间具有如下线性关系：\n",
    "\n",
    "$$ y = wx + b $$\n",
    "\n",
    "其中，$y$ 是输出变量，$x$ 是输入变量，$w$ 和 $b$ 是参数，$w$ 称为权重或斜率，$b$ 称为偏差。\n",
    "\n",
    "我们的目标是通过训练样本，找到最优的参数 $w$ 和 $b$，使得模型能够准确地预测新的样本。为了找到最优的参数，我们需要定义一个损失函数，表示模型的预测精度。在线性回归中，常用的损失函数是均方误差（mean squared error, MSE）：\n",
    "\n",
    "$$ MSE = \\frac{1}{n}\\sum_{i=1}^{n}(y_{i} - \\hat{y}_{i})^{2} $$\n",
    "\n",
    "其中，$y_{i}$ 是第 $i$ 个样本的真实输出值，$\\hat{y}_{i}$ 是模型预测的输出值，$n$ 是样本数量。\n",
    "\n",
    "可以使用梯度下降算法来最小化损失函数，即不断更新参数 $w$ 和 $b$，使得损失函数的值越来越小。\n",
    "\n",
    "具体来说，我们需要计算损失函数对 $w$ 和 $b$ 的梯度，然后按照如下公式更新参数：\n",
    "\n",
    "$$ w = w - \\alpha \\frac{\\partial L}{\\partial w} $$\n",
    "$$ b = b - \\alpha \\frac{\\partial L}{\\partial b} $$\n",
    "\n",
    "其中，$\\alpha$ 是学习率，$L$ 是损失函数，$\\frac{\\partial L}{\\partial w}$ 和 $\\frac{\\partial L}{\\partial b}$ 分别表示损失函数对 $w$ 和 $b$ 的梯度。\n",
    "\n",
    "我们可以重复进行这个过程，直到损失函数的值足够小，或者达到了设定的最大迭代次数为止。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f407593e",
   "metadata": {},
   "source": [
    "### 代码实现"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6fdce9d0",
   "metadata": {},
   "source": [
    "#### 数据生成"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9326cdd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "# 设置随机数种子，使得每次运行代码生成的数据相同\n",
    "np.random.seed(42)\n",
    "\n",
    "# 生成随机数据，w为2，b为1\n",
    "x = np.random.rand(100, 1)\n",
    "y = 1 + 2 * x + 0.1 * np.random.randn(100, 1)\n",
    "\n",
    "# 将数据转换为 pytorch tensor\n",
    "x_tensor = torch.from_numpy(x).float()\n",
    "y_tensor = torch.from_numpy(y).float()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dcbde546",
   "metadata": {},
   "source": [
    "#### 设置超参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b260e02c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 设置超参数\n",
    "learning_rate = 0.1\n",
    "num_epochs = 1000"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c6c5823",
   "metadata": {},
   "source": [
    "#### 初始化参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "be6bc47e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 初始化参数，可以使用常数、随机数或预训练等\n",
    "w = torch.randn(1, requires_grad = True)\n",
    "b = torch.zeros(1, requires_grad = True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e6c94f4",
   "metadata": {},
   "source": [
    "#### 开始训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "edbed799",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "w: tensor([1.9540], requires_grad=True)\n",
      "b: tensor([1.0215], requires_grad=True)\n"
     ]
    }
   ],
   "source": [
    "# 开始训练\n",
    "for epoch in range(num_epochs):\n",
    "    # 计算预测值\n",
    "    y_pred = x_tensor * w + b\n",
    "\n",
    "    # 计算损失\n",
    "    loss = ((y_pred - y_tensor) ** 2).mean()\n",
    "\n",
    "    # 反向传播\n",
    "    loss.backward()\n",
    "\n",
    "    # 更新参数\n",
    "    with torch.no_grad():\n",
    "        w -= learning_rate * w.grad\n",
    "        b -= learning_rate * b.grad\n",
    "\n",
    "        # 清空梯度\n",
    "        w.grad.zero_()\n",
    "        b.grad.zero_()\n",
    "\n",
    "# 输出训练后的参数，与数据生成时设置的常数基本一致\n",
    "print('w:', w)\n",
    "print('b:', b)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da9e98a3",
   "metadata": {},
   "source": [
    "#### 可视化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2e1d414b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8+yak3AAAACXBIWXMAAAsTAAALEwEAmpwYAAAh40lEQVR4nO3de5xcZZ3n8c+vO5XuCmAXS1BJS+ioEQhkJUwv4PbuInjh5pAGwUAEnFkWvI2jLyQab7FhLomTUVZXR2QHF1GawQHSiQQXL40v3LwM0iHckogDJCM0IOHSjZCupLvz7B9V1amqc6rqVNWp+/f9euVl9zmnznmOCb9++nl+z+8x5xwiItL42mrdABERCYcCuohIk1BAFxFpEgroIiJNQgFdRKRJzKrVg+fOnet6enpq9XgRkYa0ZcuWF51zh/udq1lA7+npYWRkpFaPFxFpSGb277nOachFRKRJKKCLiDQJBXQRkSahgC4i0iQU0EVEmkTNslxERFrN0NZR1t7zOM+OTTAvFmXFGUfTv6Q7tPsroIuIVMHQ1lG+cOejTExOAzA6NsEX7nwUILSgriEXEZEqWHvP4zPBPGVicpq19zwe2jMU0EVEquDZsYmijpdCAV1EpArmxaJFHS9FwYBuZp1m9lsze9jMtpnZNT7XdJjZbWb2hJndb2Y9obVQRKQJrDjjaKKR9oxj0Ug7K844OrRnBOmh7wVOd869EzgBONPMTsm65nLgFefc24HrgK+F1kIRkSbQv6Sb1ecvpjsWxYDuWJTV5y+ubpaLS2w6+lry20jyT/ZGpEuBgeTXtwPfNjNz2rBURGRG/5LuUAN4tkBpi2bWDmwB3g58xzl3f9Yl3cDTAM65KTMbBw4DXsy6z5XAlQDz588vr+UiInWs0jnnfgJNijrnpp1zJwBvAU4ys+NLeZhz7gbnXK9zrvfww33L+YqINLxUzvno2ASOAznnQ1tHK/rcorJcnHNjwL3AmVmnRoEjAcxsFtAFvBRC+0REGk41cs79BMlyOdzMYsmvo8D7gN9lXbYB+Ejy6wuAYY2fi0irqkbOuZ8gPfQjgHvN7BHgAeDnzrm7zOxaMzs3ec2NwGFm9gRwFbCyMs0VEal/uXLLT33D8/Dd/wLPPVyR5wbJcnkEWOJzfFXa13HgwnCbJiLSmFaccXRG3ZbZTPLzjs9z1L7n4Y/Aq8/CEe8M/bkqziUiTaUW2SXZUs9be8/jnPWn2/ly5JYDJz98Oyx8X0Weq4AuIk2j3IqGYf4w6N9zO/3xVYmVOwDHfxA+eCOYlXS/IBTQRaRp5MsuKRSYg/4wSA/6sTkRnIPxickDPwCOPRjWZK2z+ezv4ZA3hfCG+Smgi0jTKCe7JMgPg+yg/8qeyZlrR8cm6F+/CNZn3XhgvIg3KI8Cuog0jXmxKKM+wTs96yTXsEqQHwZ+QR/gr9vv5KrI7ZkHv7wbZs0u8U1Ko/K5ItI0ClU0zLeCM0h52+yg38E+dnUuzwjmayc/xNDS7VUP5qAeuog0kfTsEr+JzXzDKtmphgCRNmPPvikWrNzIvFiUrmiEsYnEMMuuzuWe5/fEBwHoDjBmXwkK6CLSVPJVNMw3rJL9w6ArGuH1fVMz4+SjYxNE2o1rIzdxWfvPMj7/H+M38CoHz3w/OjZB35rhqqdOKqCLSMsoNMae/sOgb83wTG8cwNjPv0Uuyfjc5v3HctG+r3juZzDznEpsBp2LxtBFpGUUs2tQem9+V+dydnZmBnMGxnn+vDs89zO8G0ZUozAXKKCLSAspZtegebEon26/wzNWvmz2d2ZSEf3ul6sqYaULc4GGXESkxQTdNWhT/LwDqzyTjp2+jdVnLc57v741wwVTJytFAV1EJN1Al+fQgvgg82JRVgeY3PTLlgl7M+hcFNBFRAAeuBE2XpV57JI74e3vYWcRtymUOllJCugiIj698nKW7Fd6M+hcFNBFpHWFHMhrTQFdRBpCqHXOnxyGH56Xceix467m+Au9OeU1bWeRFNBFpO6VW+c8g0+vvCc+SPSRdla/fbSs4BtqO0ugPHQRqXv5arAENtDlCeY98Vtm6q+EsfgnlHaWQQFdROpeOXXOef4xTyD/5fSSZCDP3D2o3MU/ZbUzBAroIlL3gpS29TXQBdf3ZR0bZ9VBq3wvL3fxT8ntDIkCuojUvWJqsAC+wyt8efdMBkvR96tUO0OmSVERqXuBF+uMPwPXHZdx6FUOYnjpA/SnbThRqcU/tVxUBGDO5SolU1m9vb1uZGSkJs8WkSaUI3sFEr3kXEW4Go2ZbXHO9fqdUw9dROpS4Hxun0DeG/8uL3LgePZmz81KAV1E6k6gfO59e+Dvj/B8dkF80LeEbbUyTWpJk6IiUncK5nMPdHmD+cA4DIzXPNOklhTQRaTu5OpN3zVxqXeI5aP3ZdRfqXWmSS1pyEVEfFW7JsnQ1lGu+cm2mU2ZMzl2dX7Ye9inkFatM01qSQFdRDyqXZNkaOsoK25/mMlp7+h39hZwcGDDiRVb/Wuv1Kp8ba1pyEVEPKpdk2TtPY97gvlds7/oCeaf3/9JepKTnqkfMkNbRyvSpkakHrqIeFS7Jkn2ff165X2d6zx7dbZKOmJQ6qGLiEeujJDYnAh9a4ZZsHIjfWuGQ+sdp563q3O5J5j3da6DgfGaF75qBAroIuLhlykSaTdei08xOjYR+pDHj958qyeQ3zp1Ggv33TqTndLK6YhBKaCLiEf/km5Wn7+Y7lgUA7pjUQ6aPYvJ/Znj3KGMqw90sWDXjzMO9cQH+VrkE6y98J0zwymtnI4YlMbQRcRXdqbIgpUbfa8recgjz36eu3K0B1ozHTEoBXQRCWReLOqZlEwdL8oDN8LGqzKPzTkMPvdUwY+2ajpiUBpyEZFAQhnyGOjyBPOhpdsDBXMpTD10EQmkrCEPn+GVt8Z/xH7aiFZxE+VmVzCgm9mRwM3AmwAH3OCc+2bWNe8G1gM7k4fudM5dG2pLRaTmih7y2LUJbjrbczhVpxyUSx6mID30KeCzzrkHzewQYIuZ/dw5tz3rul875z4QfhNFpF4Ere8ytHWU/vWLPMfTA3k65ZKHo2BAd849BzyX/PpPZrYD6AayA7qINLHA9V0GuujP+uwJ8e8xxiE5761c8nAUNSlqZj3AEuB+n9PvMrOHzeynZnacz3nM7EozGzGzkd27dxffWhGpmVz1Xa75ybbEN+PP5NwGLl8wVy55eAJPiprZwcAdwGecc69mnX4QOMo595qZnQ0MAQuz7+GcuwG4ARJ7ipbaaBGpvlzDIq/smcy7n2c+3colD1WggG5mERLB/Bbn3J3Z59MDvHPubjP7JzOb65x7Mbymikgt+eWh+xXR+u+z/5HhV+cVvF93LMqmlaeH1j4JMORiZgbcCOxwzn0jxzVvTl6HmZ2UvO9LYTZURKpjaOuobwGu9GGR2Uz6BnMGxjn3rLM9+ep+NMwSviA99D7gUuBRM3soeeyLwHwA59z1wAXAx81sCpgALnLOaUhFpMEUmvgc2LCNh9yFns/1da6b6W2n56v7rSwFaDPlnVeC1Sru9vb2upGRkZo8W0T89a0Z9g3C3bEom+LneY5/bvIKftL2Xlafv9g3QPfkqP8CsGvNOeU1tkWZ2RbnXK/fOa0UFZEZuSY+/YJ5ahu41XkmNQ+dE/HdI/TQOZHyGiq+FNBF6lC1N2hOyZ74zDVODgeWheeTawBAA7KVoeJcInUmNY5diY0kCkkV4Hqq48PeYL6ofyaYBzU+4e2d5zsu5VFAF6kz1d6gOV3/km52tC+jzbK60APj8KEfFH0/7TJUXQroInWmZntnDnR5Fgj1da5LlLctkXYZqi6NoYvUmdA2kgjqjivg0cwt4CZdOwv3/hDiOeq1BKRdhqpLAV2kzqw44+iMXHAIv1ebmnT1y17JXrJfbnlb7TJUPRpyEakzfhs058rzLkWqtG12MB86dxsLVN62oamHLlKHKtarfeCf6d/4Wc/hnvgg3T/7ffWHeyRUCugiraJARcRnxya4btkJFR/ukcpRQBdpdj6B/Jj4/yFOR8axebGoJjEbnAK6SIMoevXo0w/Aje/13mfpduzORyFHL1yTmI1LAV2kAQTe/i3Fp1eeWuXZn/xWvfDmo4Au0gDyrR7N3s/T41MPwmFvyzikXnhzUkAXaQAFV4++/iKsfZv3giJrr0hjU0AXqVPpY+ZtZkz7lCicF4vmHV6R1qKALlKHssfM/YL5rs7lEM86+KGbYdHSKrRQ6pECukgd8hszB2g3AzfNk52XeD+U1SuvVU11qR0FdJE6lGvM/MmOi70HfYZXis6KkaaggC4SgrB7w4F2Dur7NLzvWt/PB86Kkaai4lwiZarEDkPpdcRzbgOXI5hDDWuqS02phy5Spkr0hvuXdNO/fhG0Z51IG17J91uBimy1JvXQRcoUem94zXxvKuJhCz3BPN9vBdopqDWphy5SpnJ7w+k97Z25hleyFPqtQEW2WpMCukiZytlhKNXT3tG+DDqzTuZZHBTktwIt7289GnIRKVNqh6FD50RmjnXMCvaf1tRdVyeCeZa+znV5P5er968x8tamgC4Skvjk/pmvxyYmC2e6DHRxwfTdGYd64oP0xAcLjr9rjFz8aMhFJARFZbr41F55a/xH7E/rXxXqaWuMXPwooIuEIFCmy/b18OPLPNccO30b+znww8BIZK30rRnOG6Q1Ri7ZFNBFQlAw0yVPRcTVySyX0bEJDEiV4dJyfSmWxtBFQpBrTHtT/DxvMF/xVEYGS/+SbjatPJ3uWJTsmoqpYRuRIBTQRUKQynTpjkUx4KQ3jPlmrzAwDgcd5nsPLdeXcmnIRSQkM2PaA12wL+tkgA0ntFxfyqUeukhYBrq8wytXDAfePUipiFIu9dBFyrXvdfj7ed7jRW4Dp1REKZcCukg5CuznWWyddKUiSjkU0EWKkArQm+LneU++/2/hP38q49pidg3SlnFSLgV0kYASAfoRdrRf5D1ZQkVE7721ZZyUR5OiIgH1r1/kCeY98cGchbSKSUPMF/xFgioY0M3sSDO718y2m9k2M/u0zzVmZt8ysyfM7BEzO7EyzRWpAZ/slW37j6InPgjgm2oIxVVEVA66hCFID30K+KxzbhFwCvBJM1uUdc1ZwMLknyuB74baSpFa8Zn07IkPcs6+1TPft5v5frSYNESVw5UwFBxDd849BzyX/PpPZrYD6Aa2p122FLjZOeeAzWYWM7Mjkp8VqRuBJx5zBHI/0y57wX5CMWmIpx1zOD/a/Aff4yJBFTUpamY9wBLg/qxT3cDTad8/kzymgC51I9DE43XHw/jT3g8PjNO9Zth3eKU7Ty86aBrivb/bXdRxET+BJ0XN7GDgDuAzzrlXS3mYmV1pZiNmNrJ7t/6hSnUVnHgc6PIG84HxmQyWSq7k1Bi6hCFQD93MIiSC+S3OuTt9LhkFjkz7/i3JYxmcczcANwD09vb6/54qUiG5gmOiImLWwRyLg2JzInTMamN8YjLUXHHVcZEwBMlyMeBGYIdz7hs5LtsAXJbMdjkFGNf4udSb7OB49azb2NW53HthVjD/wp2PMjo2gQNe2TPJ3qn9XLfsBDatPD20HHHVcZEwBOmh9wGXAo+a2UPJY18E5gM4564H7gbOBp4A9gB/GXpLRcq04oyjZ8bQCwXylKK2liuD6rhIGIJkufw/Erti5bvGAZ8Mq1EildC/pJv+9YugPevEl56HiP/QRjXHtlXHRcqlpf/SFAqmIz7xC/jRB70fLFARUWPb0kgU0KXhFUxHLFARMf0+2T8U0odp0u3ZN8XQ1lH1qKWumMuxKKLSent73cjISE2eLY0vPfi2mfku7vEdJ/+rEZi70Pd+2YE7Gmln9fmLARjYsI2xicmMz6TOK6hLNZnZFudcr985FeeShpOdeZIdzN/Ey7knPX2CORSe/Dyow/vLrIpnSb3RkIs0HL/gmxI0eyVboclPLfyRRqAeujQcvyC6q3O5J5g/8Gf/GHgbuNicSN7jKp4ljUABXRpOehBtxz+nfGjpdv7Tn18R+J65ppJSx7XwRxqBhlyk4aQyT3a0L/OeTPbI+4u853jWhGf2cS38kUaggC4Np3/9IvqzFgf9eOpUvnnwZ1hRYiphkHxzLfyReqeALo0lX53yMvbh9Ms315CKNBoFdClZVXep9wnkfZ3rPL3qUuusaEhFmoECupSkarvU+63ybJsFq17i2ZUbfT9SaiqhhlSk0SnLRUpSlV3qcy3ZX/USoFRCkWzqoUtJKrrQJmDtFY17i2RSQJdAssfLY3MivLLHm+pXVu/4lgvh337mPZ5jcVChce+qjvGL1AEFdCnIb7w80mZE2o3J6QMrcorpHWcH203x87wXBVjlmWvcu2pj/CJ1RAFdCvIbL5/c74hFIxzUMatgDzg7eJ92zOHcsWX0wM5B8awPrHoF2sqb3qnWTkMi9UQBXQrKNS4+PjHJQ199f97P+vWUb9n8By5rv4drOn/g/UDA2iuFqJiWtCIFdCko1yrKrmiEvjXDeXvofj3lnT61V3rigxiws8JtVgaMNDOlLUpBfoWpIm3G6/umZmqSp8aoh7aOZlyX3iP2q4h4Yvz6mZWeYQZbFdOSVqSALgX1L+lm9fmL6Y5FMaA7FuXgzlkZE6Lgn4c+LxZlsT3lWxGxJz7Iy7wBSPyA2LNvigUrN9K3ZtjzgyGMNmt3IWl22oJOSrJg5Ub8/uUYsHPNOQcO5Ku9khSLRnh935QnY0YBWMRLW9BJ6Aqu0hzo8gTzZXu/4gnmBhzUEay3LyL5KaBLSXKNUX/h9Hk5C2nd7471HJ8XiyojRSQkCuhSEr8x6h3ty/jA3SdnXjgwDgPjeScpVZNFJBxKW5SSzazSHOjyLg467Utw6ucyroXcy/RVk0WkfAroUjrn4JqY93ie2it+k5yqRS4SDgV0KU3AiohBqRa5SPkU0KU4foH8jYvgE7+pfltEJIMCuuSVXljLb8l+WLVXRKR8CugtopTa4KnCWjval0Fn1kkFcpG6o4BeRyq1IUOh2uDpz43NieBcopLiAx0fY0f7q5779XWuY1ON30lEvBTQ60QlN2QotP9n+nNTuxDlqr0CYPFgC360yYRIdWlhUZ2o5KbL+VZiZj/XryJiT3wwY8l+0AU/VdlIWkRmKKDXiUouf8+3EjN1/y/MuiVvrzylmAU/WtIvUl0K6HWiksvfCy2739W5nI/O2phxPr1X3m5WUglaLekXqS4F9DpRyQ0ZctYGX7/IsznzO+I/yOiVR9qNQzpzT7UMbR2lb82wbx1zbTIhUl2qh15HqpYR8uQw/PA8z+EF8cGMLJfYnAivxaeY3J/5b+TQORG++ufHAf41WNJ78cpyEQlXvnroCuitpogl+31rhn335YRE4O6Y1cbYxKTnXHcsyqaVp5fVTBHxV9YGF2b2fTN7wcwey3H+3WY2bmYPJf+sKrfBUgE+G06c3/E9FsQHc275lm/ycmJy2jeYF/qciFROkDz0m4BvAzfnuebXzrkPhNIiCderz8E3jvEcPnb6NibG8+eHz4tFc/bQ89Gkp0htFOyhO+fuA16uQlskbANd3mA+ME5f57pA+eF+k5rpDp0T0aSnSB0Ja6Xou8zsYeBZ4Grn3LaQ7itZAk0y+o2TX7oO3pYY1w6aH56678CGbZ7hlWikfWZiVJOeIvUhjID+IHCUc+41MzsbGAIW+l1oZlcCVwLMnz8/hEc3l0LBuuBS+ulJ+Ju53htnTXrmGkrxGypJ1SnP1zYFcJH6ECjLxcx6gLucc8cHuHYX0OucezHfdcpyyZQdrMGbApgr66Q7FvXkkwM5s1eCPEtE6lO+LJeye+hm9mbgj845Z2YnkRiXf6nc+9aLauVR56p78pnbHmLtPY+z4oyjfYdKdnUu9+7neern4bQv5nyWtnwTaU4FA7qZ3Qq8G5hrZs8AXwUiAM6564ELgI+b2RQwAVzkapXcHrJqVgvMl+qXem5XNJIxlu1XeyVonfJcW75pIZBI4yoY0J1zFxc4/20SaY1NJ1+1wLCDXKEUwYnJaTojbUQj7YkNJ7KFsOGEyt2KNDbVcsmjmtUCC6UIAvx6+lJPMH8ldnxouwep3K1IY9MGF3kUkw1SrvRxbb9n5hpeOTTENqjcrUhjUw89j2pXC+xf0s2mlafzP5edMPNcvw0nGBivyJ6eKncr0tgU0PPIWXa2wuPJ/Uu6uafn1rImPUuhcrcijU1DLgXkygYJW3p2yc7O5XiWXVUwkKconVGksal8bg2lgvjo2AQG7PTpkQ+du421P/u9AqyIABVeWCTBZOd3n3bM4dyxZZSJyWnObtvMP83+luczJ9i/snfdY0ojFJFAFNCrwC+/+5bNf8Dhn71yYAs4b73xSuXBi0jjU0CvAr/8br/hlePiN/I6hTNKlEYoIn4U0MsQdJl8egB+m43yy44VnmvSN2aGRHZJZ6SNV/Z4e+lKIxQRPwroJcq3TB4yM0VSNVjyD6+AAY5EemQqVdCvKqLSCEXEjwJ6EdJ75G1mTGdlCE1MTnPNT7YRn9yfEeh3dS6Hzsx7nTP9dZb82cl0/253wR6+0ghFJAgF9ICye+TZwTwlfYjkICbY1nm555q37b2Vi08+kr/tX1zwudXKgxeRxqeAHpDfxGY++YdXHHdsGaX3qP+gYC0ioWm5gF5qve8gmSW5Stt+fN+n+en+kzOOBUk/VG1yESlGSwX0cup956q82G7GfueY19XJpr3ne85nZ6+ky/dDQrXJRaRYLVWcq5x637kKV339Q+9kZ+dyTzDv61zHgvgg3bEosWjE95750g9Vm1xEitVSPfRy6n37Fa7a2PElYuu3Z1y345i/4tiL/o5NacdybcqcL/1QtclFpFgtFdDL3bAiI+NkoMuzOXNPfJDotnZWbx3NGBYppYphNTfXEJHm0FIBfcUZR5e/UGegy3MofZw812RnsemHobRVRFpKS42hl7Vhxc9XeYL5v0yd5jvpGcawSK021xCRxtVSPXTwDn+kJhnzBkqfXjkD4/yvNcNQwWERLSoSkWI0XUAvlLtdVDpgjkCeomEREaknTbVjUa5skvShir41w76TjSntZqw6dpSPPHV15ok3LoJP/Mb3mVr8IyLV0jI7FuXL3U4F2ULj2092XAxPZR302c8zO5Bft+wEBXIRqammCuhBcrdzpQP61V5h1cvQ1u45rFWcIlKPmirLJddkZPrxFWccTaTdZr5/i72Qu5CWTzAHreIUkfrU8D309KGPrmiESLsxOX1gXiDSbry+d4oFKzfOjHEfNHtWwQ0n2s0851K0ilNE6lFDB/TsoY+xiUkibcahcyKM7ZkkNifCa/EpxiYSNcpTQyM/tb+mp/OPGfdaFP8+e9J2obj45CNzPrfQKk5NlIpILTT0kIvf0Mfkfsec2bPYueYc5syexeT+A731Tvayo30ZPW2ZwbwnPjgTzNvNuOSU+Xk3n8hVqGvFGUfP/JAZHZvAceCHyNDW0TLfVkQkv4buoRca+kg/X2g/z+z0xnzy1WbpWzNcMNNGRKQSGjqgFxr6mBeL0venu/mHyP/OOH9Ox01cceZJdJcxLJJrFafG10WkVho6oPut1ATYs2+KoQefZlP8PEgrRf6Ci3Hq/u+x+szFFVtWryqJIlIrDR3QUwF5YMO2mYlPgF9Nf4SuDXsyrl0QH2ReLMrqAD3xciY1VQ5ARGqloQJ6rkC79p7HGZuY5AR7gqGOVZkf+uzjcMib2VnEM8pZNFRK7XMRkTA0TEDPF2ifHZvwTHr+89RZ/N3Upew85M1FPSdI+YBCVCVRRGqhYQJ6rkD7i7vvYGfnVzKOp7JXuksYt9akpog0qoYJ6NkB9WD28NuOTzJncu/MsRPj1/MybwBKH7fWpKaINKqGWViUHlBXzrqVxzr/B3MsGcwv/zlDS7cTjb2p7N198i0aEhGpZw3TQ09ljxw1tZOPzfoJADftP4fYeWvpP7Kb/iPDqXSoSU0RaVQFA7qZfR/4APCCc+54n/MGfBM4G9gD/IVz7sGwG5oKqN/4v+1c/dpHeeTg/8onzjyxIimImtQUkUYUpId+E/Bt4OYc588CFib/nAx8N/m/oTsQaN8X6HrVLReRVlJwDN05dx/wcp5LlgI3u4TNQMzMjgirgeVQ3XIRaSVhTIp2A0+nff9M8piHmV1pZiNmNrJ79+4QHp2fUhBFpJVUNcvFOXeDc67XOdd7+OGHV/x5QXYwEhFpFmEE9FEgfTeItySP1ZxSEEWklYQR0DcAl1nCKcC4c+65EO5btv4l3aw+fzHdsWjZ+ekiIvUuSNrircC7gblm9gzwVZJFaZ1z1wN3k0hZfIJE2uJfVqqxpVAKooi0ioIB3Tl3cYHzDvhkaC0SEZGSNMzSfxERyU8BXUSkSSigi4g0CQV0EZEmYYk5zRo82Gw38O8lfHQu8GLIzWkEeu/WovduLcW891HOOd+VmTUL6KUysxHnXG+t21Fteu/WovduLWG9t4ZcRESahAK6iEiTaMSAfkOtG1Ajeu/WovduLaG8d8ONoYuIiL9G7KGLiIgPBXQRkSZRlwHdzM40s8fN7AkzW+lzvsPMbkuev9/MemrQzNAFeO+rzGy7mT1iZr80s6Nq0c5KKPTuadd90MycmTV8aluQdzazDyX/zreZ2WC121gpAf6tzzeze81sa/Lf+9m1aGeYzOz7ZvaCmT2W47yZ2beS/588YmYnFv0Q51xd/QHagSeBtwKzgYeBRVnXfAK4Pvn1RcBttW53ld77NGBO8uuPN8N7B3335HWHAPcBm4HeWre7Cn/fC4GtwKHJ799Y63ZX8d1vAD6e/HoRsKvW7Q7hvf8bcCLwWI7zZwM/BQw4Bbi/2GfUYw/9JOAJ59xTzrl9wL+Q2Ig63VLgB8mvbwfeY2ZWxTZWQsH3ds7d65zbk/x2M4ndoZpBkL9zgL8BvgbEq9m4CgnyzlcA33HOvQLgnHuhym2slCDv7oA3JL/uAp6tYvsqwjl3H/BynkuWAje7hM1AzMyOKOYZ9RjQg2w6PXONc24KGAcOq0rrKifwZttJl5P4ad4MCr578tfPI51zG6vZsAoK8vf9DuAdZrbJzDab2ZlVa11lBXn3AeCS5KY6dwOfqk7TaqrYGOBRcIMLqT9mdgnQC5xa67ZUg5m1Ad8A/qLGTam2WSSGXd5N4rex+8xssXNurJaNqpKLgZucc183s3cBPzSz451z+2vdsHpWjz30IJtOz1xjZrNI/Er2UlVaVzmBNts2s/cCXwLOdc7trVLbKq3Qux8CHA/8ysx2kRhf3NDgE6NB/r6fATY45yadczuB35MI8I0uyLtfDvwYwDn3G6CTRAGrZhYoBuRTjwH9AWChmS0ws9kkJj03ZF2zAfhI8usLgGGXnFVoYAXf28yWAN8jEcybZTwVCry7c27cOTfXOdfjnOshMX9wrnNupDbNDUWQf+dDJHrnmNlcEkMwT1WxjZUS5N3/ALwHwMyOJRHQd1e1ldW3Abgsme1yCjDunHuuqDvUeuY3z2zv70nMhH8peexaEv8RQ+Iv919JbEz9W+CttW5zld77F8AfgYeSfzbUus3Vevesa39Fg2e5BPz7NhJDTduBR4GLat3mKr77ImATiQyYh4D317rNIbzzrcBzwCSJ374uBz4GfCzt7/s7yf9PHi3l37iW/ouINIl6HHIREZESKKCLiDQJBXQRkSahgC4i0iQU0EVEmoQCuohIk1BAFxFpEv8fXucfQGdcekcAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.plot(x, y, 'o')\n",
    "plt.plot(x_tensor.numpy(), y_pred.detach().numpy())\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6673f9a",
   "metadata": {},
   "source": [
    "### 完整代码"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b0d2ed6a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "w: tensor([1.9540], requires_grad=True)\n",
      "b: tensor([1.0215], requires_grad=True)\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "# 设置随机数种子，使得每次运行代码生成的数据相同\n",
    "np.random.seed(42)\n",
    "\n",
    "# 生成随机数据，w为2，b为1\n",
    "x = np.random.rand(100, 1)\n",
    "y = 1 + 2 * x + 0.1 * np.random.randn(100, 1)\n",
    "\n",
    "# 将数据转换为 pytorch tensor\n",
    "x_tensor = torch.from_numpy(x).float()\n",
    "y_tensor = torch.from_numpy(y).float()\n",
    "\n",
    "# 设置超参数\n",
    "learning_rate = 0.1\n",
    "num_epochs = 1000\n",
    "\n",
    "# 初始化参数，可以使用常数、随机数或预训练等\n",
    "w = torch.randn(1, requires_grad=True)\n",
    "b = torch.zeros(1, requires_grad=True)\n",
    "\n",
    "# 开始训练\n",
    "for epoch in range(num_epochs):\n",
    "    # 计算预测值\n",
    "    y_pred = x_tensor * w + b\n",
    "\n",
    "    # 计算损失\n",
    "    loss = ((y_pred - y_tensor) ** 2).mean()\n",
    "\n",
    "    # 反向传播\n",
    "    loss.backward()\n",
    "\n",
    "    # 更新参数\n",
    "    with torch.no_grad():\n",
    "        w -= learning_rate * w.grad\n",
    "        b -= learning_rate * b.grad\n",
    "\n",
    "        # 清空梯度\n",
    "        w.grad.zero_()\n",
    "        b.grad.zero_()\n",
    "\n",
    "# 输出训练后的参数，与数据生成时设置的常数基本一致\n",
    "print('w:', w)\n",
    "print('b:', b)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b633db9",
   "metadata": {},
   "source": [
    "### Pytorch模型实现"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "85684228",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "w: tensor([[1.9540]])\n",
      "b: tensor([1.0215])\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "# 设置随机数种子，使得每次运行代码生成的数据相同\n",
    "np.random.seed(42)\n",
    "\n",
    "# 生成随机数据\n",
    "x = np.random.rand(100, 1)\n",
    "y = 1 + 2 * x + 0.1 * np.random.randn(100, 1)\n",
    "\n",
    "# 将数据转换为 pytorch tensor\n",
    "x_tensor = torch.from_numpy(x).float()\n",
    "y_tensor = torch.from_numpy(y).float()\n",
    "\n",
    "# 设置超参数\n",
    "learning_rate = 0.1\n",
    "num_epochs = 1000\n",
    "\n",
    "# 定义输入数据的维度和输出数据的维度\n",
    "input_dim = 1\n",
    "output_dim = 1\n",
    "\n",
    "# 定义模型，就是一个神经元\n",
    "model = nn.Linear(input_dim, output_dim)\n",
    "\n",
    "# 定义损失函数和优化器\n",
    "criterion = nn.MSELoss()\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)\n",
    "\n",
    "# 开始训练\n",
    "for epoch in range(num_epochs):\n",
    "    # 将输入数据喂给模型\n",
    "    y_pred = model(x_tensor)\n",
    "\n",
    "    # 计算损失\n",
    "    loss = criterion(y_pred, y_tensor)\n",
    "    \n",
    "    # 清空梯度\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    # 反向传播\n",
    "    loss.backward()\n",
    "\n",
    "    # 更新参数\n",
    "    optimizer.step()\n",
    "\n",
    "# 输出训练后的参数\n",
    "print('w:', model.weight.data)\n",
    "print('b:', model.bias.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37131c58",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {
    "height": "68px",
    "width": "172px"
   },
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {
    "height": "calc(100% - 180px)",
    "left": "10px",
    "top": "150px",
    "width": "180.6px"
   },
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
